#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Aug 27 10:12:30 2020

@author: Jonah
"""
from IPython import get_ipython
#get_ipython().run_line_magic('matplotlib', 'inline')
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.svm import SVC # "Support vector classifier"
from sklearn.svm import LinearSVC # "Support vector classifier"
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import LeaveOneOut
from sklearn.preprocessing import label_binarize
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import average_precision_score
from sklearn.metrics import f1_score
from sklearn.metrics import confusion_matrix
import seaborn as sns
from matplotlib import pyplot

# ---- Training data ---------------------------------------------------------
# Load the training table. The CSV is generated by the included R script;
# point the path at wherever that file lives on your machine.
Training = pd.read_csv(r'Luminex_data_training_set_for_LOOCV.csv')

# Skip the first two columns and treat everything that remains as floats.
Training_frame = Training.values[:, 2:].astype(float)

# Predictors are the first 18 remaining columns; the last column is the label.
Training_MMPs_cytokines = Training_frame[:, :18]
Training_OA = Training_frame[:, -1]

# Leave-one-out splitter reused by the cross-validation loops further down.
loo = LeaveOneOut()
loo.get_n_splits(Training_MMPs_cytokines)

# Grid-search the regularisation strength C of the L1-penalised linear SVM.
# Every candidate C is scored by the ROC AUC of its pooled leave-one-out
# decision values; the best (AUC, C) pair is printed at the end.
# =============================================================================
rng_C = np.logspace(-5, 5, 11)
best_score = 0
C_best = None  # stays None only if no candidate reaches an AUC above 0
for i in rng_C:
    OA_model = LinearSVC(penalty='l1', C=i, dual=False, max_iter=10000000)
    y_scores = []
    for train_index, test_index in loo.split(Training_MMPs_cytokines):
        X_train, X_test = Training_MMPs_cytokines[train_index], Training_MMPs_cytokines[test_index]
        y_train = Training_OA[train_index]
        # decision_function returns one value per held-out sample, so
        # extend() keeps y_scores a flat list aligned with Training_OA.
        y_scores.extend(OA_model.fit(X_train, y_train).decision_function(X_test))
    fpr, tpr, _ = roc_curve(Training_OA, y_scores, pos_label=1)
    temp_auc = auc(fpr, tpr)
    if temp_auc > best_score:
        best_score = temp_auc
        C_best = i
print(best_score, C_best)
# =============================================================================

# Leave-one-out evaluation of the L1-penalised linear SVM with a fixed C.
# NOTE(review): C is hard-coded to 0.1 rather than taken from the C_best found
# by the search above -- confirm this is intentional.
# Alternative that was tried:
#OA_model = SVC(kernel='linear', C=0.0001, class_weight="balanced")
OA_model = LinearSVC(penalty='l1', C=0.1, dual=False, max_iter=10000000)

y_scores = []
y_predict = []
for train_index, test_index in loo.split(Training_MMPs_cytokines):
    print("TEST:", test_index)
    X_train, X_test = Training_MMPs_cytokines[train_index], Training_MMPs_cytokines[test_index]
    y_train = Training_OA[train_index]
    OA_model.fit(X_train, y_train)
    # One held-out sample per fold: collect its decision value and predicted
    # label. A per-fold ROC curve is not computed because a ROC curve needs
    # both classes and a single sample cannot provide them.
    y_scores.extend(OA_model.decision_function(X_test))
    y_predict.extend(OA_model.predict(X_test))

# Flat 1-D arrays aligned with Training_OA, ready for the sklearn metrics below.
y_scores = np.asarray(y_scores)
y_predict = np.asarray(y_predict)

# Coefficients of the model fitted in the last fold (for interactive inspection).
OA_model.coef_
    
# Confusion matrix of the pooled leave-one-out predictions.
# np.ravel accepts both a flat array and a list of one-element arrays.
cf_matrix = confusion_matrix(Training_OA, np.ravel(y_predict))
group_names = ['True Neg', 'False Pos', 'False Neg', 'True Pos']
group_counts = ["{0:0.0f}".format(value) for value in
                cf_matrix.flatten()]
group_percentages = ["{0:.2%}".format(value) for value in
                     cf_matrix.flatten()/np.sum(cf_matrix)]
labels = [f"{v1}\n{v2}\n{v3}" for v1, v2, v3 in
          zip(group_names, group_counts, group_percentages)]
labels = np.asarray(labels).reshape(2, 2)
sns.heatmap(cf_matrix, annot=labels, fmt='', cmap='Blues')

# Derive the summary metrics from the matrix itself so they always match the
# data. For a binary problem the flattened order is TN, FP, FN, TP, the same
# order as group_names above.
tn, fp, fn, tp = (int(count) for count in cf_matrix.ravel())
total = tn + fp + fn + tp
nan = float("nan")  # reported when a denominator is zero
print("Accuracy (%correct): ", (tp + tn) / total if total else nan)
print("Precision (PPV): ", tp / (tp + fp) if (tp + fp) else nan)
print("Recall (Sensitivity): ", tp / (tp + fn) if (tp + fn) else nan)
print("NPV: ", tn / (tn + fn) if (tn + fn) else nan)
print("Specificity: ", tn / (tn + fp) if (tn + fp) else nan)
    
from collections import Counter

# Class balance of the training labels (the value is only visible when this
# line is evaluated interactively).
Counter(Training_OA)

# ROC curve and its area, from the pooled leave-one-out decision values.
fpr, tpr, _ = roc_curve(y_true=Training_OA, y_score=y_scores, pos_label=1)
roc_auc = auc(x=fpr, y=tpr)
OA_model.coef_

# Precision-recall curve and the area underneath it.
precision, recall, thresholds = precision_recall_curve(Training_OA, y_scores)
prc_auc = auc(x=recall, y=precision)

# F1 score (harmonic mean of precision and recall) of the predicted labels.
f1 = f1_score(y_true=Training_OA, y_pred=y_predict)
    
# ROC curve of the pooled leave-one-out decision values, with the chance
# diagonal for reference. The legend reports the ROC AUC.
plt.figure()
plt.plot(fpr, tpr, label='ROC curve (area = %0.2f)' % roc_auc)
plt.plot([0, 1], [0, 1], 'k--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic')
plt.legend(loc="lower right")
plt.show()

# Precision-recall curve for the same decision values.
plt.figure()
plt.plot(recall, precision)
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision Recall')
plt.show()

# ---- Held-out test data ----------------------------------------------------
# Load the test table (same layout note as the training CSV: adjust the path
# to match your directory).
Testing = pd.read_csv(r'Luminex_data_testing_set_for_LOOCV.csv')
Testing_frame = Testing.values[:, 2:].astype(float)
Testing_OA = Testing_frame[:, -1]
# The model is fitted on Training_MMPs_cytokines, so the test matrix must have
# exactly the same number of feature columns. Slicing a different width makes
# decision_function/predict fail on a feature-count mismatch and can pull
# extra columns into the predictors.
# NOTE(review): assumes both CSVs share the same column layout -- confirm.
Testing_MMPs_cytokines = Testing_frame[:, :Training_MMPs_cytokines.shape[1]]

from collections import Counter
# Class balance of the test labels (only visible when evaluated interactively).
Counter(Testing_OA)

# Refit the model on the full training set, then score the held-out test set.
testing_scores = OA_model.fit(Training_MMPs_cytokines, Training_OA).decision_function(Testing_MMPs_cytokines)
testing_predictions = OA_model.predict(Testing_MMPs_cytokines)

# ROC curve on the test set. The legend shows the computed AUC; the format
# string needs a %-specifier, otherwise the % operator raises TypeError.
fpr_test, tpr_test, _ = roc_curve(Testing_OA, testing_scores, pos_label=1)
roc_auc_test = auc(fpr_test, tpr_test)
plt.figure()
plt.plot(fpr_test, tpr_test, label='ROC curve (area = %0.2f)' % roc_auc_test)
plt.plot([0, 1], [0, 1], 'k--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic')
plt.legend(loc="lower right")
plt.show()

# Test-set precision-recall curve and the area underneath it.
precision_test, recall_test, thresholds_test = precision_recall_curve(Testing_OA, testing_scores)
prc_auc_test = auc(x=recall_test, y=precision_test)

# Test-set F1 score (harmonic mean of precision and recall).
f1_test = f1_score(y_true=Testing_OA, y_pred=testing_predictions)

# Draw the precision-recall curve in its own figure.
plt.figure()
plt.plot(recall_test, precision_test)
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision Recall')
plt.show()

# Confusion matrix of the held-out test predictions.
cf_matrix_test = confusion_matrix(Testing_OA, testing_predictions)
group_names = ['True Neg', 'False Pos', 'False Neg', 'True Pos']
group_counts = ["{0:0.0f}".format(value) for value in
                cf_matrix_test.flatten()]
group_percentages = ["{0:.2%}".format(value) for value in
                     cf_matrix_test.flatten()/np.sum(cf_matrix_test)]
labels = [f"{v1}\n{v2}\n{v3}" for v1, v2, v3 in
          zip(group_names, group_counts, group_percentages)]
labels = np.asarray(labels).reshape(2, 2)
# Use a dedicated figure and show it explicitly. This heatmap is drawn after
# the last show() call, so it would not appear when the file runs as a script.
plt.figure()
sns.heatmap(cf_matrix_test, annot=labels, fmt='', cmap='Blues')
plt.show()

# Derive the summary metrics from the matrix itself so they always match the
# data. For a binary problem the flattened order is TN, FP, FN, TP, the same
# order as group_names above.
tn, fp, fn, tp = (int(count) for count in cf_matrix_test.ravel())
total = tn + fp + fn + tp
nan = float("nan")  # reported when a denominator is zero
print("Accuracy (%correct): ", (tp + tn) / total if total else nan)
print("Precision (PPV): ", tp / (tp + fp) if (tp + fp) else nan)
print("Recall (Sensitivity): ", tp / (tp + fn) if (tp + fn) else nan)
print("NPV: ", tn / (tn + fn) if (tn + fn) else nan)
print("Specificity: ", tn / (tn + fp) if (tn + fp) else nan)

# Record the interpreter version used for this run.
import sys
print(sys.version)